package tsj.ai.alibaba.data.analysis.util;

import com.alibaba.cloud.ai.graph.GraphResponse;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.streaming.FluxConverter;
import com.alibaba.cloud.ai.graph.streaming.StreamingOutput;
import org.springframework.ai.chat.model.ChatResponse;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.util.Map;
import java.util.function.Function;

/**
 * @author taoshujian
 * @version 1.0
 * @since 2025-11-14 14:29:02
 */
public class FluxUtil {
    private FluxUtil() {

    }

    /**
     * 级联两个具有前后关系的Flux
     * @param originFlux 第一个Flux
     * @param nextFluxFunc 根据第一个Flux的聚合结果生成第二个Flux
     * @param aggregator 聚合第一个Flux的所有数据
     * @param preFlux 在第一个Flux前添加的信息Flux
     * @param middleFlux 在第一个Flux和第二个Flux之间添加的信息Flux
     * @param endFlux 在第二个Flux后添加的信息Flux
     */
    public static <T, R> Flux<T> cascadeFlux(Flux<T> originFlux, Function<R, Flux<T>> nextFluxFunc,
                                             Function<Flux<T>, Mono<R>> aggregator, Flux<T> preFlux, Flux<T> middleFlux, Flux<T> endFlux) {
        // 缓存原始流避免重复执行
        Flux<T> cachedOrigin = originFlux.cache();

        // 聚合结果
        Mono<R> aggregatedResult = aggregator.apply(cachedOrigin).cache();

        // 构建完整流
        Flux<T> secondFlux = aggregatedResult.flatMapMany(nextFluxFunc);

        return preFlux.concatWith(cachedOrigin).concatWith(middleFlux).concatWith(secondFlux).concatWith(endFlux);
    }

    /**
     * 级联两个具有前后关系的Flux
     * @param originFlux 第一个Flux
     * @param nextFluxFunc 根据第一个Flux的聚合结果生成第二个Flux
     * @param aggregator 聚合第一个Flux的所有数据
     */
    public static <T, R> Flux<T> cascadeFlux(Flux<T> originFlux, Function<R, Flux<T>> nextFluxFunc,
                                             Function<Flux<T>, Mono<R>> aggregator) {
        return cascadeFlux(originFlux, nextFluxFunc, aggregator, Flux.empty(), Flux.empty(), Flux.empty());
    }

    /**
     * Quickly create streaming generator with start and end messages
     *
     * 快速创建带有开始和结束消息的流媒体生成器
     *
     * @param nodeClass node class
     * @param state state
     * @param startMessage start message
     * @param completionMessage completion message
     * @param resultMapper result mapping function
     * @param sourceFlux source data stream
     * @return AsyncGenerator instance
     */
    public static Flux<GraphResponse<StreamingOutput>> createStreamingGeneratorWithMessages(
            Class<? extends NodeAction> nodeClass, OverAllState state, String startMessage, String completionMessage,
            Function<String, Map<String, Object>> resultMapper, Flux<ChatResponse> sourceFlux) {
        String nodeName = nodeClass.getSimpleName();

        // Used to collect actual processing results
        final StringBuilder collectedResult = new StringBuilder();

        // wrapperFlux
        Flux<ChatResponse> startFlux = (startMessage == null ? Flux.empty()
                : Flux.just(ChatResponseUtil.createResponse(startMessage)));
        Flux<ChatResponse> wrapperFlux = startFlux.concatWith(sourceFlux.doOnNext(chatResponse -> {
            String text = ChatResponseUtil.getText(chatResponse);
            collectedResult.append(text);
        }));
        if (completionMessage != null) {
            wrapperFlux = wrapperFlux.concatWith(Flux.just(ChatResponseUtil.createResponse(completionMessage)));
        }
        return FluxConverter.builder()
                .startingNode(nodeName)
                .startingState(state)
                .mapResult(r -> resultMapper.apply(collectedResult.toString()))
                .build(wrapperFlux);
    }

    public static Flux<GraphResponse<StreamingOutput>> createStreamingGeneratorWithMessages(
            Class<? extends NodeAction> nodeClass, OverAllState state,
            Function<String, Map<String, Object>> resultMapper, Flux<ChatResponse> sourceFlux) {
        return createStreamingGeneratorWithMessages(nodeClass, state, null, null, resultMapper, sourceFlux);
    }
}
